import base64

import numpy as np
import pandas as pd
from torch.utils.data import IterableDataset, get_worker_info

import torch
import torch.nn.functional as F
from torchvision.io import decode_image, ImageReadMode
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
from torchvision.transforms import ConvertImageDtype


class ImageDataset(IterableDataset):
    """Stream base64-encoded images from a JSON-lines file as preprocessed tensors.

    Each line of ``fp_imgs`` is read as a record with (at least) the fields
    ``qry_id``, ``img_rank`` and ``img_b64``. Iterating yields
    ``(qry_id, img_rank, image)`` tuples, where ``image`` is the decoded RGB
    image preprocessed according to ``model_type``.

    Under a multi-worker ``DataLoader`` the records are split round-robin across
    workers (worker ``k`` takes records ``k, k + num_workers, ...``), so each
    record is yielded by exactly one worker. Every worker still parses the whole
    file; it just skips the records it does not own.

    Args:
        fp_imgs: Path (or anything ``pd.read_json`` accepts) of the JSON-lines file.
        model_type: ``"yolo_v3"`` (convert to float, pad to square, resize to
            ``img_size``) or ``"mask_rcnn"`` (torchvision detection preset).
            Any other value yields the decoded image with no preprocessing.
        img_size: Side length of the square output used by ``"yolo_v3"``.
    """

    def __init__(self, fp_imgs: str, model_type: str = "yolo_v3", img_size: int = 416):
        self.fp_imgs = fp_imgs
        self.model_type = model_type
        self.img_size = img_size

    def pad_to_square(self, img, pad_value=0):
        """Pad a ``(C, H, W)`` image with ``pad_value`` so that ``H == W``.

        The shorter spatial side is padded on both ends. When the size
        difference is odd, the extra row/column goes to the bottom/right.
        Adapted from
        https://github.com/mkocabas/yolov3-pytorch/blob/master/yolov3/utils/datasets.py
        """
        _, h, w = img.shape
        dim_diff = abs(h - w)
        # (upper / left) padding and (lower / right) padding
        before, after = dim_diff // 2, dim_diff - dim_diff // 2
        # F.pad's tuple is (left, right, top, bottom) for the last two dims:
        # pad rows when the image is wide, columns when it is tall.
        pad = (0, 0, before, after) if h <= w else (before, after, 0, 0)
        return F.pad(img, pad, "constant", value=pad_value)

    def resize(self, img):
        """Resize a ``(C, H, W)`` image to ``(C, img_size, img_size)`` (nearest neighbour).

        Adapted from
        https://github.com/mkocabas/yolov3-pytorch/blob/master/yolov3/utils/datasets.py
        """
        # F.interpolate expects a batch dimension; add it and strip it again.
        return F.interpolate(
            img.unsqueeze(0), size=self.img_size, mode="nearest"
        ).squeeze(0)

    @staticmethod
    def _worker_shard():
        """Return ``(num_workers, worker_id)`` for the current loading process.

        ``get_worker_info()`` returns ``None`` when iterating in the main
        process (e.g. ``DataLoader(num_workers=0)`` or ``iter(dataset)``).
        That case is treated as a single worker that owns every record.
        """
        worker_info = get_worker_info()
        if worker_info is None:
            return 1, 0
        return worker_info.num_workers, worker_info.id

    def _iter_rows(self, num_workers=1, worker_id=0):
        """Yield the records of ``self.fp_imgs`` owned by worker ``worker_id``.

        The file is read in chunks of ``num_workers`` lines. The worker takes
        the row at offset ``worker_id`` of each chunk.
        """
        with pd.read_json(self.fp_imgs, lines=True, chunksize=num_workers) as reader:
            for chunk in reader:
                # The final chunk may hold fewer than num_workers rows; workers
                # whose offset falls past its end have nothing left to read.
                if worker_id < len(chunk):
                    yield chunk.iloc[worker_id]

    @staticmethod
    def _decode(img_b64):
        """Decode a base64 string of encoded image bytes into an RGB image tensor."""
        # np.frombuffer returns a read-only view; copy it so the tensor built
        # by torch.from_numpy is backed by a writable buffer.
        raw = np.frombuffer(base64.b64decode(img_b64), dtype="uint8").copy()
        return decode_image(torch.from_numpy(raw), mode=ImageReadMode.RGB)

    def _build_transform(self):
        """Return the per-image preprocessing callable for ``self.model_type``.

        Returns ``None`` for unrecognized model types. In that case images are
        yielded exactly as decoded.
        """
        if self.model_type == "yolo_v3":
            to_float = ConvertImageDtype(torch.float)

            def yolo_transform(img):
                return self.resize(self.pad_to_square(to_float(img)))

            return yolo_transform
        if self.model_type == "mask_rcnn":
            # NOTE(review): these are the *Keypoint* R-CNN weights' transforms
            # although the branch is named mask_rcnn; kept as-is, confirm intent.
            return KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
        return None

    def __iter__(self):
        """Yield ``(qry_id, img_rank, image)`` for every record owned by this worker."""
        num_workers, worker_id = self._worker_shard()
        # Build the preprocessing pipeline once per iteration, not per image.
        transform = self._build_transform()
        for record in self._iter_rows(num_workers, worker_id):
            img_data = self._decode(record["img_b64"])
            if transform is not None:
                img_data = transform(img_data)
            yield record["qry_id"], record["img_rank"], img_data
